# -*- coding: utf-8 -*-
"""
Created on Sat Oct 8 11:40:09 2022
Last edited on Thu Nov 28 2024

@author: Andrei Sontag
"""


# Import required modules
import numpy as np
import pandas as pd
import seaborn as sn
import os, fnmatch
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns

# Group sizes (number of players per group) whose sessions are processed.
N_vals = [12, 13, 16, 18, 19, 20, 22, 25, 28, 31, 33]
# To process a single group size, restrict the list, e.g. N_vals = [28]

# Per-session plot settings, keyed by the data-file stem "<N>_<session>".
# Each row: (stem, number of histogram bins, upper y-limit of the density axis)
_session_plot_settings = [
    ('12_1', 11, 0.12),
    ('13_1', 9, 0.067),
    ('16_1', 11, 0.06),
    ('18_1', 12, 0.04),
    ('19_1', 13, 0.05),
    ('19_2', 13, 0.06),
    ('20_1', 14, 0.05),
    ('22_1', 15, 0.042),
    ('25_1', 17, 0.07),
    ('28_1', 19, 0.09),
    ('31_1', 21, 0.075),
    ('33_1', 22, 0.04),
]
# number of bins for the symmetrised z histograms
bins_file = {stem: nbins for stem, nbins, _ in _session_plot_settings}
# y-axis upper limit for those density histograms
dens_limit = {stem: ymax for stem, _, ymax in _session_plot_settings}
del _session_plot_settings

def _state_counts(x_votes, y_votes, N):
    """Return an (N+1, N+1) array counting how often each (X, Y) vote state occurs.

    Entry ``[y, x]`` is the number of rows in which there were ``y`` Y votes
    and ``x`` X votes (row index = Y, column index = X). Vote counts are
    truncated to int before being used as indices.
    """
    counts = np.zeros((N + 1, N + 1))
    # np.add.at accumulates repeated (y, x) pairs; plain fancy-index "+= 1" would not
    np.add.at(counts, (np.asarray(y_votes).astype(int), np.asarray(x_votes).astype(int)), 1)
    return counts


def _draw_state_heatmap(ax, counts, N, cmap):
    """Draw the symmetrised (X, Y) state heatmap on ``ax`` with its frame lines.

    The counts are symmetrised (counts + counts.T) and squared for better
    colour contrast. The triangle above the line X + Y = N + 1 is covered
    in white.
    """
    sns.heatmap((counts + np.transpose(counts))**2, cmap=cmap, cbar=False, ax=ax)
    ax.invert_yaxis()
    ax.set_xlabel('X votes')
    ax.set_ylabel('Y votes')

    # cover the upper triangle
    corners = [[0, N+1], [N+1, N+1], [N+1, 0]]
    triangle = np.array(corners + corners[:1])
    ax.fill(triangle[:, 0], triangle[:, 1], color='white', alpha=1)

    # outline: diagonal X + Y = N + 1 plus lines along both axes
    xax = np.array([0, N+1])
    sns.lineplot(x=xax, y=N+1-xax, linestyle='-', color='black', linewidth=1, ax=ax)
    sns.lineplot(x=xax, y=2*[0.1], linestyle='-', color='black', ax=ax)
    ax.plot(2*[0.1], xax, linestyle='-', color='black')


def _draw_z_histogram(ax, z, bins, bw_adjust, ymax):
    """Plot a density histogram with KDE of ``z`` on ``ax`` using fixed axis styling."""
    sns.histplot(z, kde=True, kde_kws={'bw_adjust': bw_adjust}, bins=bins,
                 color='steelblue', ax=ax, stat="density")
    ax.set_xlabel('z')
    ax.set(ylim=(0, ymax))
    ax.set_xticks([-10, 10])
    ax.set_xticklabels([-10, 10])


def _flow_field(params, X, Y):
    """Evaluate the deterministic vector field (Ex, Ey) on the grid (X, Y).

    ``params`` holds the nine fitted rates in the order
    arl, brl, crl, ars, brs, crs, asr, bsr, csr. Both components are set
    to zero wherever 1 - X - Y <= 0. Ey is Ex with the roles of X and Y
    swapped.
    """
    arl, brl, crl, ars, brs, crs, asr, bsr, csr = (params[i] for i in range(9))
    # NOTE(review): brl (params[1]) does not enter any coefficient below — confirm intended.

    # coefficients of the deterministic equation
    A = -arl - ars
    B = -crs
    C = -crl - brs + bsr
    D = arl
    E = asr
    F = crl + csr

    Z = 1 - X - Y
    inside = Z > 0
    Ex = (A*X + B*X*Y + C*X*Z + D*Y + E*Z + F*Y*Z) * inside
    Ey = (D*X + B*X*Y + F*X*Z + A*Y + E*Z + C*Y*Z) * inside
    return Ex, Ey


# ---- loop-invariant setup shared by every figure ----
matplotlib.rcParams.update({'font.size': 15})

# mosaic layout: A/B state heatmaps, D/E z histograms, G/H time series
layout = [
    ["A", "B"],
    ["A", "B"],
    ["A", "B"],
    ["D", "E"],
    ["D", "E"],
    ["G", "H"],
    ["G", "H"],
]
cmap = 'Blues'

rounds = 120  # number of voting rounds read from each experimental CSV
a_cols = [f"my_voting.{k}.group.A_votes" for k in range(1, rounds + 1)]
b_cols = [f"my_voting.{k}.group.B_votes" for k in range(1, rounds + 1)]

# unit grid (101 x 101) on which the flow field is evaluated
X, Y = np.meshgrid(np.linspace(0, 1, 101), np.linspace(0, 1, 101))

for N in N_vals:

    # data files for this group size are named "<N>_<session>.csv"; newest first
    pattern = f"{N}_*.csv"
    files = sorted((name for name in os.listdir('.') if fnmatch.fnmatch(name, pattern)),
                   key=os.path.getmtime, reverse=True)

    for file in files:
        print(file)
        stem = os.path.splitext(file)[0]  # e.g. "12_1"; keys bins_file/dens_limit and the side files

        # experimental data: only the first row of the CSV is used
        df = pd.read_csv(file)
        avotes = df[a_cols].to_numpy()[0]
        bvotes = df[b_cols].to_numpy()[0]
        count = _state_counts(avotes, bvotes, N)

        # simulated data: column 0 is indexed as X votes, column 1 as Y votes
        # (os.path.join keeps the relative paths portable across OSes)
        data2 = np.loadtxt(os.path.join('simulations', f'simdata_{stem}.txt'), delimiter=',')
        count2 = _state_counts(data2[:, 0], data2[:, 1], N)

        # fitted rates for the deterministic flow field
        params = np.loadtxt(os.path.join('fitted_rates', f'rates_{stem}.txt'), delimiter=',')
        Ex, Ey = _flow_field(params, X, Y)

        fig, axd = plt.subplot_mosaic(layout, figsize=(8, 9))

        # A: experimental state heatmap
        _draw_state_heatmap(axd['A'], count, N, cmap)

        # B: flow field (scaled to vote counts) under the simulated state heatmap
        scale = N + 1
        axd['B'].streamplot(X*scale, Y*scale, Ex*scale, Ey*scale,
                            density=1, linewidth=1, color='grey')
        _draw_state_heatmap(axd['B'], count2, N, cmap)

        # symmetric bin edges of width (2N+1)/nbins around z = 0
        nbins = bins_file[stem]
        width = (2*N + 1)/nbins
        half_edges = np.arange(0, N + width/2, width) + width/2
        bins = np.append(-half_edges[::-1], half_edges)
        ymax = dens_limit[stem]

        # D: symmetrised histogram of z = X - Y, experimental data
        _draw_z_histogram(axd['D'], np.append(avotes - bvotes, bvotes - avotes), bins, 1, ymax)

        # E: symmetrised histogram of z = X - Y, simulated data
        _draw_z_histogram(axd['E'],
                          np.append(data2[:, 0] - data2[:, 1], data2[:, 1] - data2[:, 0]),
                          bins, 2, ymax)

        # time series: index 0 (first round) is omitted and x runs 1..rounds-1.
        # NOTE(review): x is one less than the round number — confirm intended.
        rnds = np.arange(1, rounds)
        ts = slice(1, rounds)

        # G: experimental time series
        axd['G'].plot(rnds, avotes[ts], color='red', label='X votes')
        axd['G'].plot(rnds, bvotes[ts], color='steelblue', label='Y votes')
        axd['G'].plot(rnds, N - avotes[ts] - bvotes[ts], color='#222021', label='Abstentions')
        axd['G'].set_xlabel('Rounds')
        axd['G'].set_ylabel('# of votes')

        # H: simulated time series (labels/colours match panel G)
        axd['H'].plot(rnds, data2[ts, 1], color='steelblue', label='Y votes')
        axd['H'].plot(rnds, data2[ts, 0], color='red', label='X votes')
        axd['H'].plot(rnds, N - data2[ts, 0] - data2[ts, 1], color='#222021', label='Abstentions')
        axd['H'].set_xlabel('Rounds')
        axd['H'].set_ylabel('# of votes')

        plt.tight_layout()
        fig.subplots_adjust(hspace=2)
        # To save to file, call e.g. fig.savefig(f'./fit_spd_{stem}.svg', format='svg', dpi=300)
        # BEFORE plt.show()/plt.close(), otherwise the saved figure is blank.
        plt.show()
        # release the figure so one figure per session does not accumulate in memory
        plt.close(fig)